package org.jvm.rtda.heap.classLoader;

import org.jvm.classfile.ClassFile;
import org.jvm.rtda.Object;
import org.jvm.rtda.*;
import org.jvm.rtda.heap.*;
import org.jvm.rtda.heap.classmember.*;
import org.jvm.rtda.heap.stringpool.StringPool;
import org.jvm.rtda.thread.Thread;

import java.util.*;

/**
 * 抽象的类加载器
 * 类加载器分为两类。分别为JVM直接实现的根类加载器和其他所有java代码实现的类加载器
 *
 * @author 海燕
 * @date 2023/3/23
 */
public abstract class AbstractKlassLoader {
    /**
     * 每个类加载器有自己独立的方法区，管理自己加载的类
     */
    protected Map<String, Klass> klassMap = new HashMap<>();

    /**
     * 本类加载器对应的javaClassLoader实例，对于根类加载器来说这个字段是null
     */
    protected Object jcl;

    public AbstractKlassLoader(Object jcl) {
        this.jcl = jcl;
    }

    /**
     * 寻找当前类加载器已经加载过的类
     * 注意这里是纯JVM实现的逻辑，不使用javaCall
     *
     * @param klassName
     * @return
     */
    public Klass findLoadedKlass(String klassName) {
        return klassMap.get(klassName);
    }


    /**
     * 定义一个类，并放入本加载器的方法区
     * 对于bootClassLoader来说，这里也只执行纯JVM逻辑不使用javaCall
     * 对于其他classLoader，进行父类加载时需要进行javaCall。
     *
     * @param klassName
     * @param klassData 类的二进制数据
     * @param thread    javaCall线程
     * @return
     */
    public Klass defineKlass(String klassName, byte[] klassData, Thread thread) {
        if (klassData == null) {
            return null;
        }
        //解析二进制文件
        ClassFile classFile = new ClassFile(klassData);
        //加载class对象
        Klass klass = new Klass(classFile, this);
        hackClass(klass);
        //加载父类
        resolveSuperClass(klass, thread);
        //加载类所实现的所有接口
        resolveInterfaces(klass, thread);
        link(klass);
        fillJavaClassAndJavaClassLoader(klass);
        klassMap.put(klassName, klass);
        return klass;
    }

    /**
     * 填充Klass中jClass和jClass中对应的jClassLoader
     *
     * @param klass
     */
    protected void fillJavaClassAndJavaClassLoader(Klass klass) {
        Klass classKlass = KlassLoaderRegister.getBootKlassLoader().findLoadedKlass("java/lang/Class");
        if (classKlass != null) {
            //类对象是java/lang/Class的实例
            Object classObject = classKlass.newObject();
            //类对象和类互相持有对方的引用
            classObject.setExtra(klass);
            klass.setjClass(classObject);
            classObject.setRefVar("classLoader", "Ljava/lang/ClassLoader;", this.jcl);
        }
    }

    /**
     * 链接
     *
     * @param klass
     */
    private void link(Klass klass) {
        verify(klass);
        prepare(klass);
    }

    /**
     * 进行class验证
     * 本虚拟机省略此步骤
     *
     * @param klass
     */
    private void verify(Klass klass) {
    }

    /**
     * 类链接准备
     *
     * @param klass
     */
    private void prepare(Klass klass) {
        //计算实例字段的数量和坐标
        calcInstanceFieldSlotIds(klass);
        //计算静态字段的数量和坐标
        calcStaticFieldSlotIds(klass);
        //对静态字段进行空间申请和初始化
        allocAndInitStaticVars(klass);
    }

    /**
     * 给类静态字段分配空间并赋初始值
     * 对于static与final修饰符所共同修饰的基本类型与字符串字段，从常量池获取初始值
     *
     * @param klass
     */
    private void allocAndInitStaticVars(Klass klass) {
        Slots staticVars = new Slots(klass.getStaticSlotCount());
        klass.setStaticVars(staticVars);
        Arrays.stream(klass.getFields())
                .filter(field -> field.isFinal() && field.isStatic() && field.getConstValueIndex() > 0)
                .forEach(field -> {
                    ConstantPool constantPool = klass.getConstantPool();
                    int constValueIndex = field.getConstValueIndex();
                    //这里不考虑除字符串之外的引用是因为，类静态字段初始化是从.class文件常量池中取值，不可能有引用类型
                    switch (field.getDescriptor()) {
                        case "Z":
                        case "B":
                        case "C":
                        case "S":
                        case "I":
                            int val1 = (int) constantPool.getConstant(constValueIndex);
                            staticVars.setInt(field.getSlotId(), val1);
                            break;
                        case "J":
                            long val2 = (long) constantPool.getConstant(constValueIndex);
                            staticVars.setLong(field.getSlotId(), val2);
                            break;
                        case "F":
                            float val3 = (float) constantPool.getConstant(constValueIndex);
                            staticVars.setFloat(field.getSlotId(), val3);
                            break;
                        case "D":
                            double val4 = (double) constantPool.getConstant(constValueIndex);
                            staticVars.setDouble(field.getSlotId(), val4);
                            break;
                        case "Ljava/lang/String":
                            String val5 = (String) constantPool.getConstant(constValueIndex);
                            staticVars.setRef(field.getSlotId(), StringPool.jString(val5));
                    }
                });
    }

    /**
     * 计算本类所有静态字段的数量和坐标
     * 注意字段数量是包括本类的所有父类的静态字段数量
     *
     * @param klass
     */
    private void calcStaticFieldSlotIds(Klass klass) {
        int slotId = 0;
        if (klass.getSuperClass() != null) {
            slotId = klass.getSuperClass().getStaticSlotCount();
        }
        for (Field field : klass.getFields()) {
            if (!field.isStatic()) {
                continue;
            }
            field.setSlotId(slotId++);
            if (field.isLongOrDouble()) {
                slotId++;
            }
        }
        klass.setStaticSlotCount(slotId);
    }


    /**
     * 计算本类所有实例字段的数量和坐标
     * 注意字段数量是包括本类的所有父类的实例字段数量
     *
     * @param klass
     */
    private void calcInstanceFieldSlotIds(Klass klass) {
        int slotId = 0;
        //如果存在父类，则需要计入父类中所有实例字段的数量
        if (klass.getSuperClass() != null) {
            slotId = klass.getSuperClass().getInstanceSlotCount();
        }
        for (Field field : klass.getFields()) {
            if (field.isStatic()) {
                continue;
            }
            field.setSlotId(slotId++);
            if (field.isLongOrDouble()) {
                slotId++;
            }
        }
        klass.setInstanceSlotCount(slotId);
    }

    /**
     * 加载本类实现的所有接口
     *
     * @param klass
     */
    private void resolveInterfaces(Klass klass, Thread thread) {
        Klass[] interfaces = new Klass[klass.getInterfaceNames().length];
        klass.setInterfaces(interfaces);
        for (int i = 0; i < klass.getInterfaceNames().length; i++) {
            interfaces[i] = loadKlass(klass.getInterfaceNames()[i], thread);
        }
    }


    private void hackClass(Klass klass) {
        if (klass.getName().equals("java/lang/ClassLoader")) {
            Method loadLibrary = klass.getStaticMethod("loadLibrary", "(Ljava/lang/Class;Ljava/lang/String;Z)V");
            loadLibrary.setCode(new byte[]{(byte) 0xb1}); // return void
        }
    }

    /**
     * 递归加载本类的所有父类，加载过程中这些父类都会被放进方法区
     *
     * @param klass
     */
    private void resolveSuperClass(Klass klass, Thread thread) {
        //java/lang/Object是所有类的父类且本身没有父类
        if (klass.getName().equals("java/lang/Object")) {
            return;
        }
        klass.setSuperClass(loadKlass(klass.getSuperClassName(), thread));
    }

    /**
     * 根据类名加载一个类
     * 对于bootClassLoader来说，这里也只执行纯JVM逻辑不使用javaCall
     * 对于其他classLoader，使用javaCall进行类加载
     *
     * @param klassName
     * @param thread
     * @return
     */
    public abstract Klass loadKlass(String klassName, Thread thread);

    public Object getJcl() {
        return jcl;
    }
}